{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training with multiple GPUs\n",
    "\n",
    "Here we show how to run the training from [Training neural network with DALI and JAX](jax-basic_example.ipynb) on multiple GPUs. We use the same network and the same data pipeline; the only difference is that the work is spread across several GPUs. To follow along more easily, we recommend going through [Training neural network with DALI and JAX](jax-basic_example.ipynb) first.\n",
    "\n",
    "To learn how to run the DALI iterator on multiple GPUs, please refer to the [multiple GPU support section of Getting started with JAX and DALI](jax-getting_started.ipynb#Multiple-GPUs). The example below builds on that knowledge."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training with automatic parallelization\n",
    "\n",
    "In this section we spread the training across the GPUs with the automatic parallelization mechanisms from JAX. To do that, we need to define the `sharding` that we want to apply to the computation.\n",
    "\n",
    "To learn more about sharding, please refer to the [JAX documentation section on distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PositionalSharding([[{GPU 0}]\n",
      "                    [{GPU 1}]])\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "from jax.sharding import PositionalSharding\n",
    "from jax.experimental import mesh_utils\n",
    "\n",
    "\n",
    "mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))\n",
    "sharding = PositionalSharding(mesh)\n",
    "\n",
    "print(sharding)"
   ]
  },
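  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of a sharding concrete, here is a minimal sketch that is separate from the training code. It is an illustration under assumptions: it uses `NamedSharding` with a one-dimensional `Mesh` rather than the `PositionalSharding` created above, and the axis name `batch`, the variable `per_device` and the array of zeros are hypothetical stand-ins for a real batch of images.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from jax.sharding import Mesh, NamedSharding, PartitionSpec\n",
    "\n",
    "# One-dimensional mesh over all visible devices; the axis name is arbitrary.\n",
    "mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))\n",
    "# Split the leading (batch) dimension across the mesh axis.\n",
    "batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))\n",
    "\n",
    "# Hypothetical batch: 4 flattened 28x28 images per device.\n",
    "per_device = 4\n",
    "images = jnp.zeros((per_device * jax.device_count(), 28 * 28))\n",
    "images = jax.device_put(images, batch_sharding)\n",
    "\n",
    "# Each device holds only its own slice of the batch.\n",
    "for shard in images.addressable_shards:\n",
    "    print(shard.device, shard.data.shape)\n",
    "```\n",
    "\n",
    "The array keeps its global shape, while each device should hold only a `(4, 784)` slice of it. Functions compiled with `jax.jit` can consume such an array directly and JAX partitions the computation accordingly."
   ]
  },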
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create the DALI iterator function. We base it on the function from the [Training neural network with DALI and JAX](jax-basic_example.ipynb) example and add support for multiple GPUs by passing `sharding` to the decorator and `num_shards` and `shard_id` to the reader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:58.022258Z",
     "iopub.status.busy": "2023-07-28T07:43:58.021951Z",
     "iopub.status.idle": "2023-07-28T07:43:58.025225Z",
     "shell.execute_reply": "2023-07-28T07:43:58.024884Z"
    }
   },
   "outputs": [],
   "source": [
    "from nvidia.dali.plugin.jax import data_iterator\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "\n",
    "image_size = 28\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "@data_iterator(\n",
    "    output_map=[\"images\", \"labels\"],\n",
    "    reader_name=\"mnist_caffe2_reader\",\n",
    "    sharding=sharding,\n",
    ")\n",
    "def mnist_training_iterator(data_path, num_shards, shard_id):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        random_shuffle=True,\n",
    "        name=\"mnist_caffe2_reader\",\n",
    "        num_shards=num_shards,\n",
    "        shard_id=shard_id,\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "    images = fn.reshape(images, shape=[image_size * image_size])\n",
    "\n",
    "    labels = labels.gpu()\n",
    "    labels = fn.one_hot(labels, num_classes=num_classes)\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For simplicity, in this tutorial we run the validation on a single GPU. We create a corresponding DALI iterator function for the validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@data_iterator(\n",
    "    output_map=[\"images\", \"labels\"], reader_name=\"mnist_caffe2_reader\"\n",
    ")\n",
    "def mnist_validation_iterator(data_path):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path, random_shuffle=False, name=\"mnist_caffe2_reader\"\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "    images = fn.reshape(images, shape=[image_size * image_size])\n",
    "\n",
    "    labels = labels.gpu()\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define some parameters for training and create the iterator instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of batches in training iterator = 300\n",
      "Number of batches in validation iterator = 50\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "training_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\"\n",
    ")\n",
    "validation_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/testing/\"\n",
    ")\n",
    "\n",
    "batch_size = 200\n",
    "num_epochs = 5\n",
    "\n",
    "\n",
    "training_iterator = mnist_training_iterator(\n",
    "    batch_size=batch_size, data_path=training_data_path\n",
    ")\n",
    "print(f\"Number of batches in training iterator = {len(training_iterator)}\")\n",
    "\n",
    "validation_iterator = mnist_validation_iterator(\n",
    "    batch_size=batch_size, data_path=validation_data_path\n",
    ")\n",
    "print(f\"Number of batches in validation iterator = {len(validation_iterator)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With all this setup ready, we can start the actual training. We import model-related utilities from the [Training neural network with DALI and JAX](jax-basic_example.ipynb) example and use them to train the model.\n",
    "\n",
    "Each `batch` in the training loop contains `images` and `labels` sharded according to the `sharding` argument.\n",
    "\n",
    "Note that for validation we pull the model to one GPU. As mentioned before, this is done for simplicity. In a real-world scenario, you could run validation on all GPUs and average the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 sec\n",
      "Test set accuracy 0.6739000082015991\n",
      "Epoch 1 sec\n",
      "Test set accuracy 0.7844000458717346\n",
      "Epoch 2 sec\n",
      "Test set accuracy 0.8244000673294067\n",
      "Epoch 3 sec\n",
      "Test set accuracy 0.8455000519752502\n",
      "Epoch 4 sec\n",
      "Test set accuracy 0.860200047492981\n"
     ]
    }
   ],
   "source": [
    "from model import init_model, accuracy\n",
    "from model import update\n",
    "\n",
    "model = init_model()\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for it, batch in enumerate(training_iterator):\n",
    "        model = update(model, batch)\n",
    "\n",
    "    model_on_one_device = jax.tree_util.tree_map(\n",
    "        lambda x: jax.device_put(x, jax.devices()[0]), model\n",
    "    )\n",
    "    test_acc = accuracy(model_on_one_device, validation_iterator)\n",
    "\n",
    "    print(f\"Epoch {epoch} sec\")\n",
    "    print(f\"Test set accuracy {test_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training with `pmapped` iterator\n",
    "\n",
    "JAX offers another mechanism to spread computation across multiple devices: the `pmap` function. DALI supports this way of parallelization as well.\n",
    "\n",
    "To learn more about `pmap`, see the [JAX documentation](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#).\n",
    "\n",
    "In DALI, to configure the iterator in a way compatible with `pmapped` functions, we pass the `devices` argument instead of `sharding`. Here we use all available GPUs. The iterator returns a `batch` that is sharded across all GPUs.\n",
    "\n",
    "As with `sharding`, under the hood the iterator creates multiple instances of the DALI pipeline and assigns each instance to one GPU. When the outputs are requested, DALI synchronizes the instances and returns the results as a single `batch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@data_iterator(\n",
    "    output_map=[\"images\", \"labels\"],\n",
    "    reader_name=\"mnist_caffe2_reader\",\n",
    "    devices=jax.devices(),\n",
    ")\n",
    "def mnist_training_iterator(data_path, num_shards, shard_id):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        random_shuffle=True,\n",
    "        name=\"mnist_caffe2_reader\",\n",
    "        num_shards=num_shards,\n",
    "        shard_id=shard_id,\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "    images = fn.reshape(images, shape=[image_size * image_size])\n",
    "\n",
    "    labels = labels.gpu()\n",
    "    labels = fn.one_hot(labels, num_classes=num_classes)\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create an iterator instance the same way as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating training iterator\n",
      "Number of batches in training iterator = 300\n"
     ]
    }
   ],
   "source": [
    "print(\"Creating training iterator\")\n",
    "training_iterator = mnist_training_iterator(\n",
    "    batch_size=batch_size, data_path=training_data_path\n",
    ")\n",
    "\n",
    "print(f\"Number of batches in training iterator = {len(training_iterator)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For validation, we use the same iterator as before. Since we are running it on a single GPU, we don't need to change anything. We can again pull the model to one GPU and run the validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of batches in validation iterator = 50\n"
     ]
    }
   ],
   "source": [
    "print(f\"Number of batches in validation iterator = {len(validation_iterator)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the model to be compatible with pmap-style multiple GPU training, we need to replicate it, that is, stack one copy of every parameter per device along a new leading axis. If you want to learn more about training on multiple GPUs with `pmap`, you can look into [Parallel Evaluation in JAX](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html) from the JAX documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:58.245168Z",
     "iopub.status.busy": "2023-07-28T07:43:58.244604Z",
     "iopub.status.idle": "2023-07-28T07:43:58.292113Z",
     "shell.execute_reply": "2023-07-28T07:43:58.291012Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from model import init_model, accuracy\n",
    "\n",
    "\n",
    "model = init_model()\n",
    "model = jax.tree_util.tree_map(\n",
    "    lambda x: jnp.array([x] * jax.device_count()), model\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For multi-GPU training we import the `update_parallel` function. It is the same as the `update` function, with added gradient synchronization across the devices. This ensures that the replicas of the model on different devices remain the same.\n",
    "\n",
    "Since we want to run validation on a single GPU, we extract only one replica of the model and pass it to the `accuracy` function."
   ]
  },
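  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `update_parallel` function lives in the `model` module, which is not shown in this notebook. As a rough illustration of the idea, the sketch below shows one way a pmapped update step with gradient synchronization could look. Everything in it is an assumption made for illustration: the linear softmax classifier, the learning rate of `0.1`, the names `loss_fn` and `update_parallel_sketch`, the axis name `devices`, and the separate `images` and `labels` arguments. It is not the implementation from `model.py`.\n",
    "\n",
    "```python\n",
    "import functools\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_fn(params, images, labels):\n",
    "    # Hypothetical linear classifier with a cross-entropy loss.\n",
    "    logits = images @ params['w'] + params['b']\n",
    "    log_probs = jax.nn.log_softmax(logits)\n",
    "    return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))\n",
    "\n",
    "\n",
    "@functools.partial(jax.pmap, axis_name='devices')\n",
    "def update_parallel_sketch(params, images, labels):\n",
    "    grads = jax.grad(loss_fn)(params, images, labels)\n",
    "    # Average the gradients across devices so all replicas stay identical.\n",
    "    grads = jax.lax.pmean(grads, axis_name='devices')\n",
    "    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)\n",
    "\n",
    "\n",
    "n = jax.local_device_count()\n",
    "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n",
    "# Replicate the parameters: one copy per device along a new leading axis.\n",
    "replicated = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n), params)\n",
    "\n",
    "# Dummy per-device batches with a leading device axis, as pmap expects.\n",
    "images = jnp.ones((n, 8, 784))\n",
    "labels = jax.nn.one_hot(jnp.zeros((n, 8), dtype=jnp.int32), 10)\n",
    "\n",
    "replicated = update_parallel_sketch(replicated, images, labels)\n",
    "```\n",
    "\n",
    "Every input to a pmapped function carries a leading axis whose size equals the number of devices. Without the `pmean` call, each replica would be updated with gradients from its own part of the batch only and the replicas would drift apart."
   ]
  },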
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:58.294940Z",
     "iopub.status.busy": "2023-07-28T07:43:58.294735Z",
     "iopub.status.idle": "2023-07-28T07:44:12.533682Z",
     "shell.execute_reply": "2023-07-28T07:44:12.533238Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 sec\n",
      "Test set accuracy 0.6885000467300415\n",
      "Epoch 1 sec\n",
      "Test set accuracy 0.7829000353813171\n",
      "Epoch 2 sec\n",
      "Test set accuracy 0.8222000598907471\n",
      "Epoch 3 sec\n",
      "Test set accuracy 0.8438000679016113\n",
      "Epoch 4 sec\n",
      "Test set accuracy 0.8580000400543213\n"
     ]
    }
   ],
   "source": [
    "from model import update_parallel\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for it, batch in enumerate(training_iterator):\n",
    "        model = update_parallel(model, batch)\n",
    "\n",
    "    test_acc = accuracy(\n",
    "        jax.tree_util.tree_map(lambda x: x[0], model), validation_iterator\n",
    "    )\n",
    "\n",
    "    print(f\"Epoch {epoch} sec\")\n",
    "    print(f\"Test set accuracy {test_acc}\")"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
